{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# ES Tutorial with Fiber\n",
        "\n",
        "Jiale Zhi\n",
        "\n",
        "June 2020\n",
        "\n",
        "[Introduction to Fiber](https://eng.uber.com/fiberdistributed/)\n",
        "\n",
        "### Overview\n",
        "\n",
        "In this notebook, we will walk through:\n",
        "\n",
        "* how to install Fiber\n",
        "* how to implement parallel ES with Fiber\n",
        "* how to run parallel ES locally\n",
        "* how to run parallel ES on a remote cluster\n",
        "\n",
        "### Installing dependencies"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install fiber numpy"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: fiber in /Users/jiale/Uber/fiber-github (0.2.1.dev0)\n",
            "Requirement already satisfied: numpy in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (1.16.5)\n",
            "Requirement already satisfied: nnpy-bundle in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from fiber) (1.4.2.post1)\n",
            "Requirement already satisfied: psutil in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from fiber) (5.6.5)\n",
            "Requirement already satisfied: docker in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from fiber) (4.1.0)\n",
            "Requirement already satisfied: cloudpickle in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from fiber) (1.2.2)\n",
            "Requirement already satisfied: kubernetes in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from fiber) (10.0.1)\n",
            "Requirement already satisfied: requests in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from fiber) (2.22.0)\n",
            "Requirement already satisfied: click in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from fiber) (7.0)\n",
            "Requirement already satisfied: cffi in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from nnpy-bundle->fiber) (1.13.2)\n",
            "Requirement already satisfied: websocket-client>=0.32.0 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from docker->fiber) (0.56.0)\n",
            "Requirement already satisfied: six>=1.4.0 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from docker->fiber) (1.14.0)\n",
            "Requirement already satisfied: python-dateutil>=2.5.3 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from kubernetes->fiber) (2.8.1)\n",
            "Requirement already satisfied: pyyaml>=3.12 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from kubernetes->fiber) (5.2)\n",
            "Requirement already satisfied: google-auth>=1.0.1 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from kubernetes->fiber) (1.10.0)\n",
            "Requirement already satisfied: urllib3>=1.24.2 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from kubernetes->fiber) (1.25.7)\n",
            "Requirement already satisfied: setuptools>=21.0.0 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from kubernetes->fiber) (41.6.0.post20191030)\n",
            "Requirement already satisfied: certifi>=14.05.14 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from kubernetes->fiber) (2020.4.5.1)\n",
            "Requirement already satisfied: requests-oauthlib in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from kubernetes->fiber) (1.3.0)\n",
            "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from requests->fiber) (3.0.4)\n",
            "Requirement already satisfied: idna<2.9,>=2.5 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from requests->fiber) (2.8)\n",
            "Requirement already satisfied: pycparser in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from cffi->nnpy-bundle->fiber) (2.19)\n",
            "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from google-auth>=1.0.1->kubernetes->fiber) (4.0.0)\n",
            "Requirement already satisfied: rsa<4.1,>=3.1.4 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from google-auth>=1.0.1->kubernetes->fiber) (4.0)\n",
            "Requirement already satisfied: pyasn1-modules>=0.2.1 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from google-auth>=1.0.1->kubernetes->fiber) (0.2.7)\n",
            "Requirement already satisfied: oauthlib>=3.0.0 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from requests-oauthlib->kubernetes->fiber) (3.1.0)\n",
            "Requirement already satisfied: pyasn1>=0.1.3 in /Users/jiale/miniconda3/envs/fiber/lib/python3.6/site-packages (from rsa<4.1,>=3.1.4->google-auth>=1.0.1->kubernetes->fiber) (0.4.8)\n"
          ]
        }
      ],
      "execution_count": 24,
      "metadata": {
        "scrolled": false,
        "execution": {
          "iopub.status.busy": "2020-07-08T06:08:12.957Z",
          "iopub.execute_input": "2020-07-08T06:08:12.963Z",
          "iopub.status.idle": "2020-07-08T06:08:14.632Z",
          "shell.execute_reply": "2020-07-08T06:08:14.638Z"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Import libraries"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import fiber\n",
        "import functools\n",
        "import numpy as np"
      ],
      "outputs": [],
      "execution_count": 25,
      "metadata": {
        "execution": {
          "iopub.status.busy": "2020-07-08T06:08:17.648Z",
          "iopub.execute_input": "2020-07-08T06:08:17.652Z",
          "iopub.status.idle": "2020-07-08T06:08:17.660Z",
          "shell.execute_reply": "2020-07-08T06:08:17.664Z"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Parallelized Evolution Strategies\n",
        "\n",
        "![Parallel ES](https://github.com/uber/fiber/raw/docs/gecco-2020/tutorials/imgs/parallel_es.png)\n",
        "\n",
        "Source: [Evolution Strategies as a Scalable Alternative to Reinforcement Learning](https://arxiv.org/abs/1703.03864)"
      ],
      "metadata": {}
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Define target function"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "solution = np.array([5.0, -5.0, 1.5])\n",
        "\n",
        "def F(theta):\n",
        "    return -np.sum(np.square(theta - solution))"
      ],
      "outputs": [],
      "execution_count": 28,
      "metadata": {
        "execution": {
          "iopub.status.busy": "2020-07-08T06:08:44.414Z",
          "iopub.execute_input": "2020-07-08T06:08:44.418Z",
          "iopub.status.idle": "2020-07-08T06:08:44.426Z",
          "shell.execute_reply": "2020-07-08T06:08:44.430Z"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Define optimization worker function"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
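        "def worker(dim, sigma, theta):\n",
        "    # Draw one random perturbation. Note: rand() samples uniform noise in\n",
        "    # [0, 1); the ES paper linked above uses Gaussian noise (np.random.randn).\n",
        "    epsilon = np.random.rand(dim)\n",
        "    # Return the reward of the perturbed parameters together with the noise.\n",
        "    return F(theta + sigma * epsilon), epsilon"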
      ],
      "outputs": [],
      "execution_count": 29,
      "metadata": {
        "execution": {
          "iopub.status.busy": "2020-07-08T06:08:47.788Z",
          "iopub.execute_input": "2020-07-08T06:08:47.793Z",
          "iopub.status.idle": "2020-07-08T06:08:47.800Z",
          "shell.execute_reply": "2020-07-08T06:08:47.805Z"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Define ES main loop"
      ],
      "metadata": {}
    },
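    {
      "cell_type": "markdown",
      "source": [
        "As a reading aid, the update performed by the main loop can be sketched as follows. The notation is an assumption chosen here to mirror the code, not taken verbatim from the paper: $n$ is `workers`, $\\epsilon_i$ is the perturbation drawn by worker $i$, $r_i$ is its reward, and $\\tilde{r}_i$ is the normalized reward.\n",
        "\n",
        "$$r_i = F(\\theta_t + \\sigma \\epsilon_i), \\qquad \\tilde{r}_i = \\frac{r_i - \\operatorname{mean}(r)}{\\operatorname{std}(r)}$$\n",
        "\n",
        "$$\\theta_{t+1} = \\theta_t + \\frac{\\alpha}{n \\sigma} \\sum_{i=1}^{n} \\tilde{r}_i \\, \\epsilon_i$$\n",
        "\n",
        "Here $\\sigma$ scales the perturbations and $\\alpha$ is the step size, matching the `sigma` and `alpha` arguments of `es`."
      ],
      "metadata": {}
    },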
    {
      "cell_type": "code",
      "source": [
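        "def es(theta0, worker, workers=40, sigma=0.1, alpha=0.05, iterations=200):\n",
        "    dim = theta0.shape[0]\n",
        "    theta = theta0\n",
        "    # Pool of `workers` Fiber processes that evaluate perturbations in parallel.\n",
        "    pool = fiber.Pool(workers)\n",
        "    # Bind the fixed arguments so that pool.map only has to pass theta.\n",
        "    func = functools.partial(worker, dim, sigma)\n",
        "\n",
        "    for t in range(iterations):\n",
        "        # Each task returns a (reward, epsilon) pair.\n",
        "        returns = pool.map(func, [theta] * workers)\n",
        "        rewards = np.array([ret[0] for ret in returns])\n",
        "        epsilons = [ret[1] for ret in returns]\n",
        "        # Normalize rewards to zero mean and unit standard deviation.\n",
        "        normalized_rewards = (rewards - np.mean(rewards)) / np.std(rewards)\n",
        "        # Move theta along the reward-weighted sum of the perturbations.\n",
        "        theta = theta + alpha * 1.0 / (workers * sigma) * sum(\n",
        "            [reward * epsilon for reward, epsilon in zip(normalized_rewards, epsilons)]\n",
        "        )\n",
        "        # Print progress every 10 iterations.\n",
        "        if t % 10 == 0:\n",
        "            print(theta)\n",
        "    return theta"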
      ],
      "outputs": [],
      "execution_count": 30,
      "metadata": {
        "execution": {
          "iopub.status.busy": "2020-07-08T06:08:50.382Z",
          "iopub.execute_input": "2020-07-08T06:08:50.387Z",
          "iopub.status.idle": "2020-07-08T06:08:50.394Z",
          "shell.execute_reply": "2020-07-08T06:08:50.398Z"
        }
      }
    },
    {
      "cell_type": "code",
      "source": [
        "theta0 = np.random.rand(3)\n",
        "print(theta0)"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[0.86671061 0.73848431 0.24179934]\n"
          ]
        }
      ],
      "execution_count": 31,
      "metadata": {
        "execution": {
          "iopub.status.busy": "2020-07-08T06:08:51.531Z",
          "iopub.execute_input": "2020-07-08T06:08:51.536Z",
          "iopub.status.idle": "2020-07-08T06:08:51.546Z",
          "shell.execute_reply": "2020-07-08T06:08:51.550Z"
        }
      }
    },
    {
      "cell_type": "code",
      "source": [
        "result = es(theta0, worker)\n",
        "print(\"Result\", result)"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[0.95252593 0.62140418 0.29058771]\n",
            "[ 1.84840603 -0.58219734  0.54492689]\n",
            "[ 2.58169555 -1.68238933  0.84204997]\n",
            "[ 3.38070016 -2.85486794  1.11552064]\n",
            "[ 4.15657906 -4.01036158  1.25570931]\n",
            "[ 4.9727053  -5.06077535  1.47203385]\n",
            "[ 4.99440972 -4.97471383  1.41825278]\n",
            "[ 4.97002407 -5.01824378  1.41562208]\n",
            "[ 4.99776445 -5.09938847  1.40213522]\n",
            "[ 4.98058378 -5.08281907  1.48028614]\n",
            "[ 4.93052776 -5.11199241  1.43804138]\n",
            "[ 4.96548663 -5.11753547  1.40724279]\n",
            "[ 4.92274941 -4.975219    1.38152667]\n",
            "[ 4.90504796 -4.99724384  1.44838095]\n",
            "[ 4.93865843 -5.00612834  1.50314035]\n",
            "[ 4.89655864 -5.04379244  1.50655411]\n",
            "[ 5.00185022 -5.04523329  1.49979646]\n",
            "[ 4.93656231 -5.06518668  1.41717149]\n",
            "[ 4.98610681 -4.98321875  1.44649979]\n",
            "[ 4.90571602 -5.0343561   1.47023598]\n",
            "Result [ 4.98424945 -5.08676285  1.47894592]\n"
          ]
        }
      ],
      "execution_count": 32,
      "metadata": {
        "execution": {
          "iopub.status.busy": "2020-07-08T06:08:52.736Z",
          "iopub.execute_input": "2020-07-08T06:08:52.742Z",
          "iopub.status.idle": "2020-07-08T06:08:56.603Z",
          "shell.execute_reply": "2020-07-08T06:08:56.609Z"
        }
      }
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.6.9",
      "mimetype": "text/x-python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "file_extension": ".py"
    },
    "nteract": {
      "version": "0.24.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}